# Copyright (c) Stanford University, The Regents of the University of
#               California, and others.
#
# All Rights Reserved.
#
# See Copyright-SimVascular.txt for additional details.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject
# to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
# OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import numpy as np
import sys
import vtk
import utils
import io_utils
import marching_cube as m_c

# TODO: improve compatibility with label ids (the LVImage methods below hard-code ids such as 2, 3 and 6).
class Images(object):
    """Thin wrapper around a label-map image stored on ``self.label``.

    Each method delegates to a helper from ``io_utils``, ``utils`` or the
    marching-cube module ``m_c``; mutating methods store the helper's
    result back on ``self.label``.
    """

    def __init__(self, fn):
        """Load the label map at ``fn`` via ``io_utils.read_label_map``."""
        loaded = io_utils.read_label_map(fn)
        self.label = loaded

    def convert_to_binary(self):
        """Replace ``self.label`` with ``utils.convert_vtk_im_to_binary(self.label)``."""
        binary_label = utils.convert_vtk_im_to_binary(self.label)
        self.label = binary_label

    def resample(self, resolution, mode):
        """Resample ``self.label`` in place with ``utils.vtk_image_resample``.

        Args:
            resolution: forwarded positionally as the helper's second
                argument (presumably the slot ``LVImage.process`` fills with
                ``spacing=`` - TODO confirm).
            mode: forwarded positionally as the helper's third argument
                (presumably the slot filled with ``opt=``, e.g. 'NN').
        """
        resampled = utils.vtk_image_resample(self.label, resolution, mode)
        self.label = resampled

    def get_image(self):
        """Return the current label image (``self.label``)."""
        return self.label

    def write_image(self, fn):
        """Write ``self.label`` to the path ``fn`` via ``io_utils.write_vtk_image``."""
        io_utils.write_vtk_image(self.label, fn)

    def generate_surface(self, region_id, smooth_iter, band):
        """Extract a surface for ``region_id`` with ``m_c.vtk_marching_cube``.

        ``smooth_iter`` and ``band`` are forwarded unchanged to the helper;
        its result is returned as-is.
        """
        return m_c.vtk_marching_cube(self.label, region_id, smooth_iter, band)

class LVImage(Images):
    """Label-map image with left-heart specific clean-up and cutter tools.

    The methods below hard-code label ids (0, 2, 3, 6); per the inline
    comments, 6 is the aorta (AO) and 3 the left ventricle (LV).
    """

    def erase_boundary(self):
        """Apply ``utils.erase_boundary(volume, 2, 0)`` to the label volume in place.

        Presumably this resets label 2 on the volume boundary to 0 - TODO
        confirm against ``utils.erase_boundary``.
        """
        ### this is only needed for a more general left heart model
        from vtk.util.numpy_support import vtk_to_numpy, numpy_to_vtk
        x, y, z = self.label.GetDimensions()
        # The flat VTK scalar array reshapes to (z, y, x); transpose so the
        # volume is indexed as [x, y, z].
        pylabel = vtk_to_numpy(self.label.GetPointData().GetScalars()).reshape(z, y, x).transpose(2, 1, 0)
        pylabel = utils.erase_boundary(pylabel, 2, 0)
        # deep=True: the flattened array is a temporary. A shallow
        # numpy_to_vtk would leave VTK pointing at memory NumPy may free.
        self.label.GetPointData().SetScalars(
            numpy_to_vtk(pylabel.transpose(2, 1, 0).flatten(), deep=True))

    def process(self, remove_list):
        """Clean up the label map in place.

        Steps: resample to (1.2, 1.2, 1.2) spacing with opt='NN', apply
        ``utils.swap_labels``, remove the classes in ``remove_list``, keep the
        largest connected region of labels 6, 3 and 2, and apply a sequence
        of dilate/erode and open/close operations between labels. Finally,
        zero out the voxels that ``utils.locate_region_boundary_ids`` reports
        at the 2/6 interface.

        Args:
            remove_list: label ids passed one at a time to
                ``utils.remove_class(..., tissue, 0)`` (presumably reassigning
                them to background 0).

        Side effects:
            Replaces ``self.label`` and sets ``self.ids`` to the stacked
            boundary ids that are recolored to 0.
        """
        self.label = utils.vtk_image_resample(self.label, spacing=(1.2, 1.2, 1.2), opt='NN')
        from vtk.util.numpy_support import vtk_to_numpy, numpy_to_vtk
        pylabel = vtk_to_numpy(self.label.GetPointData().GetScalars())
        pylabel = utils.swap_labels(pylabel)

        # Remove the requested tissues (e.g. myocardium, RV, RA and PA,
        # depending on what the caller passes in remove_list).
        for tissue in remove_list:
            pylabel = utils.remove_class(pylabel, tissue, 0)
        # deep=True: pylabel goes out of scope when this method returns. A
        # shallow numpy_to_vtk could leave VTK referencing freed memory.
        self.label.GetPointData().SetScalars(numpy_to_vtk(pylabel, deep=True))
        # remove small islands
        self.label = utils.extract_largest_connected_region(self.label, 6)
        self.label = utils.extract_largest_connected_region(self.label, 3)
        self.label = utils.extract_largest_connected_region(self.label, 2)
        # remove connections between AA and LA
        self.label = utils.label_dilate_erode(self.label, 6, 3, 8)  # 6 - AO id, 3 - LV id
        self.label = utils.label_open_close(self.label, 6, 0, size=7)
        self.label = utils.label_open_close(self.label, 0, 6, size=7)
        self.label = utils.label_dilate_erode(self.label, 2, 3, 3)  # 2 - presumably LA id (TODO confirm), 3 - LV id
        self.label = utils.label_open_close(self.label, 2, 0, size=7)
        self.label = utils.label_open_close(self.label, 3, 0, size=7)
        self.label = utils.label_open_close(self.label, 0, 3, size=7)
        self.label = utils.label_open_close(self.label, 0, 2, size=7)
        # Collect voxel ids near the 2/6 interface from both sides, then
        # recolor them to 0 to separate the two regions.
        ids = utils.locate_region_boundary_ids(self.label, 2, 6, size=4., bg_id=0)
        self.ids = np.vstack((ids, utils.locate_region_boundary_ids(self.label, 6, 2, size=5., bg_id=0)))
        self.label = utils.label_open_close(self.label, 2, 0, size=7)
        self.label = utils.recolor_vtk_pixels_by_ids(self.label, self.ids, 0)

    def build_cutter(self, region_id, avoid_id, adjacent_id, FACTOR, op='valve', smooth_iter=50):
        """
        Build cutter for aorta and la

        Works on a deep copy of ``self.label``; ``self.label`` is not modified.

        Args:
            region_id: label id of the aorta or LA to build the cutter for.
            avoid_id: label id of the aorta or LA to avoid cutting into.
            adjacent_id: label id whose boundary with ``region_id`` locates
                the valve (mitral or aortic) plane.
            FACTOR: distance along the normal from the valve centroid to the
                cutting-plane origin (units of the image coordinates - TODO
                confirm).
            op: 'valve' or 'tissue', option for normal direction. Both
                options currently use the unit vector from the valve
                centroid to the centroid of ``region_id``.
            smooth_iter: currently unused. Marching cubes always runs with
                20 iterations. NOTE(review): honouring this would change the
                default (50 vs 20) - confirm intent before wiring it in.

        Returns:
            ``(cutter, (ctr_valve, nrm))``. ``cutter`` is the
            ``m_c.vtk_marching_cube`` output for ``region_id``. ``ctr_valve``
            is the mean of the valve boundary points, and ``nrm`` is the unit
            cutting normal.

        Raises:
            ValueError: if ``op`` is not 'valve' or 'tissue'.
        """
        # Fail fast, before any image processing is done.
        if op not in ('valve', 'tissue'):
            raise ValueError("Incorrect option: {!r} (expected 'valve' or 'tissue')".format(op))

        cut_Im = vtk.vtkImageData()
        cut_Im.DeepCopy(self.label)
        # locate centroid of mitral plane or aortic plane
        pts = utils.locateRegionBoundary(cut_Im, adjacent_id, region_id, size=2.)
        ctr_valve = np.mean(pts, axis=0)
        # centroid of left atrium or aorta
        ctr = utils.get_centroid(cut_Im, region_id)
        # Unit normal pointing from the valve centroid into the tissue.
        to_tissue = ctr - ctr_valve
        nrm = to_tissue / np.linalg.norm(to_tissue)
        ori = ctr_valve + FACTOR * nrm

        cut_Im = utils.recolor_vtk_pixels_by_plane(cut_Im, ori, -1. * nrm, 10, avoid_id)
        # dilate by a little bit
        cut_Im = utils.label_dilate_erode(cut_Im, region_id, 0, 8.)
        cut_Im = utils.label_dilate_erode(cut_Im, avoid_id, region_id, 2)

        # marching cube
        cutter = m_c.vtk_marching_cube(cut_Im, region_id, 20, 0.05)
        return cutter, (ctr_valve, nrm)

